{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import argparse\n",
    "import copy\n",
    "import json\n",
    "import os\n",
    "import pickle\n",
    "import random\n",
    "import re\n",
    "import sys\n",
    "from collections import defaultdict, deque\n",
    "from typing import List\n",
    "import re\n",
    "from typing import Any, List\n",
    "import ipydagred3\n",
    "import networkx as nx\n",
    "import numpy as np\n",
    "import openai\n",
    "from tqdm import tqdm\n",
    "from csqa_utils import *\n",
    "\n",
    "#os.environ[\"http_proxy\"] = \"http://localhost:7890\"\n",
    "#os.environ[\"https_proxy\"] = \"http://localhost:7890\"\n",
    "setOpenAi(keyid = 4)\n",
    "\n",
    "#GPT_MODEL = \"gpt-4-turbo\"\n",
    "#GPT_MODEL = \"gpt-4o-mini\"\n",
    "GPT_MODEL = \"gpt-3.5-turbo\"\n",
    "\n",
    "# 读取问题数据文件\n",
    "dataset = []\n",
    "with open('train_rand_split.jsonl', 'r') as f:\n",
    "    for line in f:\n",
    "        data = json.loads(line)\n",
    "        dataset.append(data)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = \"gpt-4-turbo-preview\"\n",
    "messages = [{\n",
    "    'role': 'user',\n",
    "    'content':'What is the capital of France?'\n",
    "}]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "response = openai.ChatCompletion.create(\n",
    "            model = model,\n",
    "            messages = messages,\n",
    "            temperature = 1\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<OpenAIObject chat.completion id=chatcmpl-A0xmPvxE9EuxhXEOy7HdPIN14iJ0l at 0x117879130> JSON: {\n",
       "  \"id\": \"chatcmpl-A0xmPvxE9EuxhXEOy7HdPIN14iJ0l\",\n",
       "  \"object\": \"chat.completion\",\n",
       "  \"created\": 1724793157,\n",
       "  \"model\": \"gpt-4-0125-preview\",\n",
       "  \"choices\": [\n",
       "    {\n",
       "      \"index\": 0,\n",
       "      \"message\": {\n",
       "        \"role\": \"assistant\",\n",
       "        \"content\": \"The capital of France is Paris.\",\n",
       "        \"refusal\": null\n",
       "      },\n",
       "      \"logprobs\": null,\n",
       "      \"finish_reason\": \"stop\"\n",
       "    }\n",
       "  ],\n",
       "  \"usage\": {\n",
       "    \"prompt_tokens\": 14,\n",
       "    \"completion_tokens\": 7,\n",
       "    \"total_tokens\": 21\n",
       "  },\n",
       "  \"system_fingerprint\": null\n",
       "}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "response"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'answerKey': 'B',\n",
       " 'id': '9b1b25ed5e3f1186dd019e17169a15eb',\n",
       " 'question': {'question_concept': 'dogs',\n",
       "  'choices': [{'label': 'A', 'text': 'four legs'},\n",
       "   {'label': 'B', 'text': 'males'},\n",
       "   {'label': 'C', 'text': 'electrical circuit'},\n",
       "   {'label': 'D', 'text': 'pet'},\n",
       "   {'label': 'E', 'text': 'sniff'}],\n",
       "  'stem': \"When they say dogs are a man's best friend, they aren't talking about just whats?\"}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[1444]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "problem_id = 88\n",
    "\n",
    "question = dataset[problem_id][\"question\"][\"stem\"]\n",
    "choices = dataset[problem_id][\"question\"][\"choices\"]\n",
    "gold_answer = dataset[problem_id][\"answerKey\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Where can meat last a long time?\n",
      "[{'label': 'A', 'text': 'backery'}, {'label': 'B', 'text': 'ham sandwich'}, {'label': 'C', 'text': 'fridge'}, {'label': 'D', 'text': 'butcher shop'}, {'label': 'E', 'text': 'freezer'}]\n"
     ]
    }
   ],
   "source": [
    "print(question)\n",
    "print(choices)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "任务分解 "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Understand the Meaning of the Question', 'Identify the Options, review each of the given options', 'Analyze and evaluate each option based on the question', 'Select the Best Option']\n",
      "{0: 'Understand the Meaning of the Question', 1: 'Identify the Options, review each of the given options', 2: 'Analyze and evaluate each option based on the question', 3: 'Select the Best Option'}\n"
     ]
    }
   ],
   "source": [
    "def decompose_sql(question, choices):   \n",
    "\n",
    "    prompt_for_decompose = \"\"\"\n",
    "\n",
    "You will be given a common sense question, and a set of choices of answers. \n",
    "Please help me decompose this problem into a series of step-by-step sub-problems.\n",
    "\n",
    "1 examples are as follows:\n",
    "Question: When they say dogs are a man's best friend, they aren't talking about just whats?\n",
    "Choices: \n",
    "[{{'label': 'A', 'text': 'four legs'}},\n",
    "   {{'label': 'B', 'text': 'males'}},\n",
    "   {{'label': 'C', 'text': 'electrical circuit'}},\n",
    "   {{'label': 'D', 'text': 'pet'}},\n",
    "   {{'label': 'E', 'text': 'sniff'}}]\n",
    "   \n",
    "Steps:\n",
    "1. Understand the Meaning of the Question\n",
    "2. Identify the Options, review each of the given options\n",
    "3. Analyze and evaluate each option based on the question\n",
    "4. Select the Best Option\n",
    "\n",
    "Now the question is {},\n",
    "the choices are: {},\n",
    "\n",
    "please decompose it into a series of easy-to-solve steps like the examples.\n",
    "Answer Format: (Please write each broken-down question step on a separate line, starting with a number.)\n",
    "To solve the question \"xxx\", we need to solve the following problems step by step:\n",
    "1. sub-question 1\n",
    "2. sub-question 2\n",
    "3. sub-question 3\n",
    "...\n",
    "\"\"\".format(question, choices)\n",
    "\n",
    "    Q = {\n",
    "        \"role\": \"user\",\n",
    "        \"content\": prompt_for_decompose\n",
    "    }\n",
    "    # Query = Example+[Q]\n",
    "    Query = [Q]\n",
    "    result = askChatGPT(Query, model=GPT_MODEL, temperature=1)\n",
    "    return result\n",
    "\n",
    "\n",
    "def convert_steps_to_format(decom_commands):\n",
    "    # 截取“we need to know:”后的内容\n",
    "    start_index = decom_commands.find(\"we need to solve the following problems step by step:\") + len(\"we need to solve the following problems step by step:\")\n",
    "    subtasks_text = decom_commands[start_index:].strip()\n",
    "    # 将每个子任务单独列出\n",
    "    subtasks = subtasks_text.split('\\n')\n",
    "    subtasks = [task.strip().split('. ', 1)[-1] for task in subtasks]\n",
    "    steps_dict = {index: value for index, value in enumerate(subtasks)}\n",
    "    return subtasks, steps_dict\n",
    "\n",
    "decompose_steps = decompose_sql(question, choices)\n",
    "steps, steps_dict = convert_steps_to_format(decompose_steps)\n",
    "num_steps = len(steps)\n",
    "print(steps)\n",
    "print(steps_dict)  # 只是加了一个问题的编号而已."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4\n"
     ]
    }
   ],
   "source": [
    "print(len(steps))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2\n",
      "(20, ['The question is asking about where meat can last a long time.', 'Sub-problem: Identify the Options, review each of the given options\\n\\n- Option A: bakery\\n- Option B: ham sandwich\\n- Option C: fridge\\n- Option D: butcher shop\\n- Option E: freezer\\n\\nBased on the information provided in the question, the best option for where meat can last a long time would be E: freezer.', 'Based on the question provided, where meat can last a long time is in the freezer (Option E). The low temperatures in the freezer help to preserve the meat and extend its shelf life compared to other options such as the fridge, bakery, ham sandwich, or butcher shop.', 'Sub-problem: Select the Best Option\\nAnswer: The best option for where meat can last a long time is E: freezer.'])\n",
      "(20, ['The question is asking about where meat can last a long time.', 'Sub-problem: Identify the Options, review each of the given options\\n\\n- Option A: bakery\\n- Option B: ham sandwich\\n- Option C: fridge\\n- Option D: butcher shop\\n- Option E: freezer\\n\\nBased on the information provided in the question, the best option for where meat can last a long time would be E: freezer.', 'Based on the question provided, where meat can last a long time is in the freezer (Option E). The low temperatures in the freezer help to preserve the meat and extend its shelf life compared to other options such as the fridge, bakery, ham sandwich, or butcher shop.', 'Sub-problem-3: Select the Best Option\\nThe best option where meat can last a long time is Option E: freezer. The low temperatures in the freezer help preserve the meat and extend its shelf life.'])\n"
     ]
    }
   ],
   "source": [
    "# TOT 求解\n",
    "N = 3  # 每个子问题进行N次proposal\n",
    "M = 2  # 通过评估选出M个最好的proposal\n",
    "\n",
    "solution = []\n",
    "\n",
    "for i in range(num_steps):\n",
    "    subtask = steps[i]\n",
    "    sys_q = f\"\"\"\n",
    "    There is a common sense question. I need you to solve it and give an answer.\n",
    "    \n",
    "    Here is the problem:\\n{question}\n",
    "    \n",
    "    The choices are: {choices}\n",
    "\n",
    "I have broken this  problem down into a series of smaller problems. I will assign you sub-problems one by one, and provide the results of the previous sub-problems as a reference for your reasoning.\n",
    "Please solve the problem and respond according to logic.\n",
    "\"\"\"  \n",
    "\n",
    "    subask = f\"\"\"\\nThe sub-problem to solve now is xxx: {subtask}\n",
    "Based on the information above, please provide a concise and clear answer\"\"\"\n",
    "    \n",
    "    if len(solution)==0:\n",
    "        # 第一个子问题\n",
    "        query = subask\n",
    "        Q = [{'role':'system', 'content':sys_q},\n",
    "            {'role':'user', 'content':query},]\n",
    "        for n in range(N):  # 一个子问题提问N次,获取N个解\n",
    "            result = askChatGPT(Q, model=GPT_MODEL, temperature=1)\n",
    "            eval_Q = Q + [{'role':'assistant', 'content':result}]\n",
    "            eval_Q = eval_Q + [{'role':'user', 'content':\"Please provide a confidence rating for the accuracy of this solution, on a scale from 1 to 5. Only output the number.\"}]\n",
    "            score = askChatGPT(eval_Q, model=GPT_MODEL, temperature=1)\n",
    "            score = int(score)\n",
    "            \n",
    "            solution.append((score, [result]))  # 维护一整条推理路径\n",
    "            \n",
    "        solution = sorted(solution, key=lambda x: x[0])\n",
    "        solution = solution[:M]  # 剪枝\n",
    "    else:\n",
    "        temp_solution = []\n",
    "        for m in range(M):  # 因为剪枝动态维护M个推理路径\n",
    "            answersSoFar = f\"\"\"\\nSo far, the answers to the preceding sub-problems are as follows: The format is Sub-problem-Id: xxx; Sub-problem: xxx; Answer: xxx.\"\"\"\n",
    "            for index, value in enumerate(solution[m][1]):\n",
    "                try:\n",
    "                    answersSoFar += f\"\"\"\\nSub-problem-Id: {index}; Sub-problem: {steps[index]}; Answer: {value}.\"\"\"\n",
    "                except:\n",
    "                    print('warning')\n",
    "                    print(index)\n",
    "                    print(len(solution[m][1]))\n",
    "                    print(len(steps))\n",
    "                    sys.exit(0)\n",
    "            query = answersSoFar+subask\n",
    "            Q = [{'role':'system', 'content':sys_q},\n",
    "                 {'role':'user', 'content':query},]\n",
    "            for n in range(N):  # 一个子问题提问N次,获取N个解\n",
    "                result = askChatGPT(Q, model=GPT_MODEL, temperature=1)\n",
    "                eval_Q = Q + [{'role':'assistant', 'content':result}]\n",
    "                eval_Q = eval_Q + [{'role':'user', 'content':\"Please provide a confidence rating for the accuracy of this solution, on a scale from 1 to 5. Only output the number.\"}]\n",
    "                score = askChatGPT(eval_Q, model=GPT_MODEL, temperature=1)\n",
    "                score = int(score)\n",
    "                \n",
    "                temp_solution.append((solution[m][0]+score, solution[m][1]+[result]))  # 路径score累加\n",
    "        \n",
    "        # print(len(temp_solution))  # 此时temp_solution中应该有M*N种推理路径\n",
    "        solution = sorted(temp_solution, key=lambda x: x[0])\n",
    "        solution = solution[:M]  # 剪枝 M*N->M\n",
    "\n",
    "print(len(solution))\n",
    "printSeq(solution)\n",
    "# 用额外的一次query再问一下最终的答案\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 从M个路径里挑一个最好的来问,也可以问完之后再评估一下选最好的答案\n",
    "\n",
    "user_q = f\"\"\"There is a common sense question:\\n{question}\n",
    "\n",
    "I have broken this math problem down into a series of smaller problems and each sub-problem is solved.\n",
    "The sub-problems and their corresponding answers are as follows. (Format: Sub-problem-Id: xxx; Sub-problem: xxx; Answer: xxx.)\"\"\"\n",
    "\n",
    "for index, value in enumerate(solution[0][1]):  # 这里仅仅使用了最终得分最高的1条路径来总结得到final answer\n",
    "    user_q += f\"\"\"\\nSub-problem-Id: {index}; Sub-problem: {steps[index]}; Answer: {value}.\"\"\"\n",
    "\n",
    "Q.append({'role':'user', 'content':f\"\"\"Now that all the sub-problems have been solved, so what is the final answer?\n",
    "Onle give the letter of the correct answer.\"\"\"})\n",
    "finalResult = askChatGPT(Q, model=GPT_MODEL, temperature=1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'E'"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "finalResult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'E'"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "gold_answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "correct\n"
     ]
    }
   ],
   "source": [
    "if finalResult == gold_answer:\n",
    "    print('correct')\n",
    "    success = True  # 任务未受中断,完整地结束了,所以标记为成功\n",
    "elif finalResult != gold_answer:\n",
    "    print('error')\n",
    "    success = True  # 任务未受中断,完整地结束了,所以标记为成功                       "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
